// SHRotationMatrix Class Definition
// ---------------------------------
//
// Takes a rotation matrix of the form:
//    (r[0], r[3], r[6])
//    (r[1], r[4], r[7])
//    (r[2], r[5], r[8])
// and an order.  Computes an order^2 x order^2 matrix.
//
#pragma once
//#include "../../common.hpp"
#include <cmath>
#include <cstdlib>

#include "SHCoeff.hpp"

namespace zzz{
// Builds and applies the matrix that transforms a vector of SHN*SHN
// spherical-harmonic coefficients according to a 3x3 matrix.
//
// The 3x3 input is kept column-major in the public member inMatrix:
//    (r[0], r[3], r[6])
//    (r[1], r[4], r[7])
//    (r[2], r[5], r[8])
// computeMatrix() fills an (SHN^2 x SHN^2) matrix that is zero everywhere
// except in one (2l+1)x(2l+1) block per band l (0 <= l < SHN), and
// applyMatrix() multiplies a coefficient vector by it band by band.
//
// Template parameters:
//   SHN - number of bands (the "order"); coefficient vectors hold SHN*SHN values.
//   T   - scalar type of the matrix entries and coefficients.
template <int SHN,typename T>
class SHRotationMatrix {
public:
  // Does not initialize inMatrix or the computed matrix.  Fill in the public
  // inMatrix (column-major) and call computeMatrix() before applyMatrix().
  SHRotationMatrix(){}

  // Stores the 3x3 matrix and immediately computes the SH matrix.
  //   matrix   - 9 values of the 3x3 transformation matrix.
  //   colmajor - true if 'matrix' is column-major (stored as is),
  //              false if it is row-major (transposed while copying).
  // The SH order is fixed by the template parameter SHN.
  SHRotationMatrix(const T matrix[9], bool colmajor=true)
  {
    // copy the input matrix into local storage (always column-major).
    if (colmajor)
    {
      for (int k=0; k<9; k++)
        inMatrix[k]=matrix[k];
    }
    else
    {
      for (int row=0; row<3; row++)
        for (int col=0; col<3; col++)
          inMatrix[col*3+row]=matrix[row*3+col];
    }

    // actually compute the matrix.
    computeMatrix();
  }

  // Computes the order^2 x order^2 matrix from the current inMatrix.
  void computeMatrix(void)
  {
    // initialize the whole matrix to 0's
    const int total = order*order*order*order;
    for (int k=0; k<total; k++)
      outMatrix[k] = 0;

    // 0th band {1x1 matrix} is the identity
    outMatrix[0] = 1;
    if (order < 2) return;

    // 1st band is a permutation of the 3x3 input matrix:
    // input index 0,1,2 lands on coefficient index 3,1,2 respectively.
    for (int count=0, i=-1; i<=1; i++)
      for (int j=-1; j<=1; j++)
        outMatrix[ matIndex((i+3)%3 + 1, (j+3)%3 + 1) ] = inMatrix[count++];

    // 2nd+ bands use a recurrence relation on the previous band.
    for (int l=2; l<order; l++)
    {
      const int ctr = l*(l+1);
      for (int n=-l; n<=l; n++)
        for (int m=-l; m<=l; m++)
        {
          const T u = u_i_st(l, m, n);
          const T w = w_i_st(l, m, n);

          // U and W are only evaluated when their coefficient is non-zero:
          // for |m|==l (u==0) and |m|>=l-1 (w==0) they would index rows
          // outside band l-1.  v is never zero for l>=2.
          T value = 0;
          if (u != 0)
            value = u * U_i_st(l, m, n);
          value = value + v_i_st(l, m, n) * V_i_st(l, m, n);
          if (w != 0)
            value = value + w * W_i_st(l, m, n);

          outMatrix[ matIndex(ctr + n, ctr + m) ] = value;
        }
    }
  }

  // Applies the order^2 x order^2 matrix to the coefficients of 'insh' and
  // stores the result in 'outsh'.  Both are accessed through their 'v' member,
  // which must hold at least SHN*SHN values.  'insh' and 'outsh' may share
  // the same storage.
  void applyMatrix(const SHCoeff<SHN,T> &insh, SHCoeff<SHN,T> &outsh) const
  {
    const T *in=insh.v;
    T *out=outsh.v;

    // When rotating in place, outputs would overwrite inputs that later
    // entries of the same band still need, so work from a snapshot.
    T snapshot[SHN*SHN];
    if (in == out)
    {
      for (int k=0; k<order*order; k++)
        snapshot[k]=in[k];
      in = snapshot;
    }

    // first band (order 0) is a 1x1 identity rotation matrix
    out[0] = in[0];

    // index range [bandBegin, bandEnd) of the band currently being multiplied,
    // starting with the 2nd band (order 1)
    int band=1;
    int bandBegin=1;
    int bandEnd=4;

    // multiply the rest of the matrix, one output coefficient at a time
    for (int idx=1; idx<order*order; idx++)
    {
      // only coefficients of the same band contribute
      T sum = 0;
      for (int j=bandBegin; j<bandEnd; j++)
        sum += outMatrix[ matIndex(j, idx) ] * in[j];
      out[idx] = sum;

      // last coefficient of this band: move on to the next one,
      // which has 2*band+1 entries.
      if (idx>=bandEnd-1)
      {
        band++;
        bandBegin=bandEnd;
        bandEnd+=2*band+1;
      }
    }
  }

  // 3x3 input matrix, column-major.
  T inMatrix[9];
private:
  static const int order=SHN;

  // The computed matrix, addressed through matIndex(col,row).
  T outMatrix[SHN*SHN*SHN*SHN];

  // Compute a 1D index for (col,row) in the matrix
  static int matIndex(int col, int row)
  {
    return col*order*order+row;
  }

  // Denominator shared by the u, v and w coefficients of Table B.1.
  static int coeffDenominator(int i, int t)
  {
    return (std::abs(t)==i) ? 2*i*(2*i-1) : (i+t)*(i-t);
  }

  // Computed as described in Table B.1
  static T u_i_st (int i, int s, int t)
  {
    // Convert to double BEFORE dividing; an integer division here would
    // truncate the ratio (e.g. 3/4 -> 0).
    const double numerator = double((i+s)*(i-s));
    return static_cast<T>(std::sqrt(numerator / coeffDenominator(i, t)));
  }
  static T v_i_st (int i, int s, int t)
  {
    using std::sqrt;
    const int delta = (s==0 ? 1 : 0);
    const int absS = std::abs(s);
    const T factor = 0.5 * (1 - 2*delta);
    const T numerator = (1+delta)*(i+absS-1)*(i+absS);
    const T denominator = coeffDenominator(i, t);
    return factor * sqrt(numerator / denominator);
  }

  static T w_i_st (int i, int s, int t)
  {
    using std::sqrt;
    const int delta = (s==0 ? 1 : 0);
    const int absS = std::abs(s);
    const T factor = -0.5 * (1 - delta);
    const T numerator = (i-absS-1)*(i-absS);
    const T denominator = coeffDenominator(i, t);
    return factor * sqrt(numerator / denominator);
  }


  // Computed as described in Table B.2
  T U_i_st (int i, int s, int t) const
  {
    return P_r_i_st(0,i,s,t);
  }
  T V_i_st (int i, int s, int t) const
  {
    const int delta = (std::abs(s)==1 ? 1 : 0);
    if (s == 0)
      return P_r_i_st(1,i,1,t) + P_r_i_st(-1,i,-1,t);
    if (s > 0)
      return
      std::sqrt(1.0+delta) * P_r_i_st(1,i,s-1,t) - (1-delta) * P_r_i_st(-1,i,-s+1,t);
    return
      (1-delta) * P_r_i_st(1,i,s+1,t) + std::sqrt(1.0+delta) * P_r_i_st(-1,i,-s-1,t);
  }

  T W_i_st (int i, int s, int t) const
  {
    if (s==0) return 0;
    if (s > 0) return P_r_i_st(1,i,s+1,t) + P_r_i_st(-1,i,-s-1,t);
    return P_r_i_st(1,i,s-1,t) - P_r_i_st(-1,i,-s+1,t);
  }

  // Computed as described in Table B.3.  Combines entries of the 3x3 input
  // (through R) with entries of band i-1 (through M); the form depends on
  // whether t is inside the band (|t| < i), equal to i, or equal to -i.
  T P_r_i_st (int r, int i, int s, int t) const
  {
    if (std::abs(t) < i) return R(r,0)*M(i-1,s,t);
    if (t == i) return R(r,1)*M(i-1,s,i-1) - R(r,-1)*M(i-1,s,-i+1);
    return R(r,1)*M(i-1,s,-i+1) + R(r,-1)*M(i-1,s,i-1);
  }

  // Index into the input matrix for -1 <= i,j <= 1, as per Equation B.40
  T R (int i, int j) const
  {
    const int jp = ((j+2) % 3);     // 0 <= jp < 3
    const int ip = ((i+2) % 3);     // 0 <= ip < 3
    return inMatrix[jp*3+ip];       // index into input matrix
  }

  // Index into band l, element (a,b) of the result (-l <= a,b <= l).
  // Any l <= 0 yields the single entry of band 0.
  T M(int l, int a, int b) const
  {
    if (l<=0) return outMatrix[0];
    // Find the center of band l (outMatrix[ctr,ctr])
    const int ctr = l*(l+1);
    return outMatrix[ matIndex(ctr + b, ctr + a) ];
  }
};
}